# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

###############################################################################
import torch
from torch import nn

from nemo.collections.tts.helpers.common import (
    AffineTransformationLayer,
    ConvAttention,
    Encoder,
    ExponentialClass,
    Invertible1x1Conv,
    Invertible1x1ConvLUS,
    LengthRegulator,
    LinearNorm,
    get_mask_from_lengths,
)
from nemo.collections.tts.modules.alignment import mas_width1 as mas
from nemo.collections.tts.modules.attribute_prediction_model import get_attribute_prediction_model


class FlowStep(nn.Module):
    """One normalizing-flow step: an invertible 1x1 convolution followed by
    an affine coupling transformation conditioned on ``context``.
    """

    def __init__(
        self,
        n_mel_channels,
        n_context_dim,
        n_layers,
        affine_model='simple_conv',
        scaling_fn='exp',
        matrix_decomposition='',
        affine_activation='softplus',
        use_partial_padding=False,
    ):
        super(FlowStep, self).__init__()
        # use the LU-decomposed variant of the invertible conv when requested
        conv_cls = Invertible1x1ConvLUS if matrix_decomposition == 'LUS' else Invertible1x1Conv
        self.invtbl_conv = conv_cls(n_mel_channels)
        self.affine_tfn = AffineTransformationLayer(
            n_mel_channels,
            n_context_dim,
            n_layers,
            affine_model=affine_model,
            scaling_fn=scaling_fn,
            affine_activation=affine_activation,
            use_partial_padding=use_partial_padding,
        )

    def forward(self, z, context, inverse=False, seq_lens=None):
        """Training direction (inverse=False) maps mel -> z and returns
        (z, log_det_W, log_s); inference direction (inverse=True) maps
        z -> mel and returns the mel tensor only.
        """
        if inverse:  # inference: undo the affine transform, then the 1x1 conv
            transformed = self.affine_tfn(z, context, inverse, seq_lens=seq_lens)
            return self.invtbl_conv(transformed, inverse)
        # training: 1x1 conv first, then affine; collect log-determinant terms
        z, log_det_W = self.invtbl_conv(z)
        z, log_s = self.affine_tfn(z, context, seq_lens=seq_lens)
        return z, log_det_W, log_s


class RadTTSModule(torch.nn.Module):
    """RAD-TTS: normalizing-flow mel-spectrogram decoder with optional
    attribute prediction sub-modules.

    Which sub-modules are built is controlled by ``include_modules``:
    'atn' (alignment attention), 'dec' (flow decoder), 'dpm' (duration
    predictor), 'apm' (f0/energy predictors), 'vpred' (voicing predictor).
    """

    def __init__(
        self,
        n_speakers,
        n_speaker_dim,
        n_text,
        n_text_dim,
        n_flows,
        n_conv_layers_per_step,
        n_mel_channels,
        n_hidden,
        mel_encoder_n_hidden,
        dummy_speaker_embedding,
        n_early_size,
        n_early_every,
        n_group_size,
        affine_model,
        dur_model_config,
        f0_model_config,
        energy_model_config,
        v_model_config=None,
        include_modules='dec',
        scaling_fn='exp',
        matrix_decomposition='',
        learn_alignments=False,
        affine_activation='softplus',
        attn_use_CTC=True,
        use_context_lstm=False,
        context_lstm_norm=None,
        text_encoder_lstm_norm=None,
        n_f0_dims=0,
        n_energy_avg_dims=0,
        context_lstm_w_f0_and_energy=True,
        use_first_order_features=False,
        unvoiced_bias_activation='',
        ap_pred_log_f0=False,
        **kwargs
    ):
        """Build the configured sub-modules.

        NOTE(review): the following kwargs are required (accessed without
        defaults below): decoder_use_partial_padding,
        decoder_use_unvoiced_bias, ap_use_unvoiced_bias,
        ap_use_voiced_embeddings.
        """
        super(RadTTSModule, self).__init__()
        assert n_early_size % 2 == 0
        self.n_mel_channels = n_mel_channels
        self.n_f0_dims = n_f0_dims  # >= 1 to train with f0
        self.n_energy_avg_dims = n_energy_avg_dims  # >= 1 to train with energy
        self.decoder_use_partial_padding = kwargs['decoder_use_partial_padding']
        self.n_speaker_dim = n_speaker_dim
        assert self.n_speaker_dim % 2 == 0
        self.speaker_embedding = torch.nn.Embedding(n_speakers, self.n_speaker_dim)
        self.embedding = torch.nn.Embedding(n_text, n_text_dim)
        self.flows = torch.nn.ModuleList()
        self.encoder = Encoder(
            encoder_embedding_dim=n_text_dim, norm_fn=nn.InstanceNorm1d, lstm_norm_fn=text_encoder_lstm_norm
        )
        self.dummy_speaker_embedding = dummy_speaker_embedding
        self.learn_alignments = learn_alignments
        self.affine_activation = affine_activation
        self.include_modules = include_modules
        self.attn_use_CTC = bool(attn_use_CTC)
        self.use_context_lstm = bool(use_context_lstm)
        self.context_lstm_norm = context_lstm_norm
        self.context_lstm_w_f0_and_energy = context_lstm_w_f0_and_energy
        self.length_regulator = LengthRegulator()
        self.use_first_order_features = bool(use_first_order_features)
        self.decoder_use_unvoiced_bias = kwargs['decoder_use_unvoiced_bias']
        self.ap_pred_log_f0 = ap_pred_log_f0
        self.ap_use_unvoiced_bias = kwargs['ap_use_unvoiced_bias']
        # --- attention + flow decoder ---
        if 'atn' in include_modules or 'dec' in include_modules:
            if self.learn_alignments:
                self.attention = ConvAttention(n_mel_channels, self.n_speaker_dim, n_text_dim)

            self.n_flows = n_flows
            self.n_group_size = n_group_size

            # per-flow-step conditioning dims: speaker vector plus grouped
            # (text + f0 + energy) features
            n_flowstep_cond_dims = self.n_speaker_dim + (n_text_dim + n_f0_dims + n_energy_avg_dims) * n_group_size

            if self.use_context_lstm:
                n_in_context_lstm = self.n_speaker_dim + n_text_dim * n_group_size
                n_context_lstm_hidden = int((self.n_speaker_dim + n_text_dim * n_group_size) / 2)

                if self.context_lstm_w_f0_and_energy:
                    # f0/energy are consumed by the context LSTM, so the flow
                    # steps only see the LSTM output dims computed below
                    n_in_context_lstm = n_f0_dims + n_energy_avg_dims + n_text_dim
                    n_in_context_lstm *= n_group_size
                    n_in_context_lstm += self.n_speaker_dim

                    n_context_hidden = n_f0_dims + n_energy_avg_dims + n_text_dim
                    n_context_hidden = n_context_hidden * n_group_size / 2
                    n_context_hidden = self.n_speaker_dim + n_context_hidden
                    n_context_hidden = int(n_context_hidden)

                    n_flowstep_cond_dims = self.n_speaker_dim + n_text_dim * n_group_size

                self.context_lstm = torch.nn.LSTM(
                    input_size=n_in_context_lstm,
                    hidden_size=n_context_lstm_hidden,
                    num_layers=1,
                    batch_first=True,
                    bidirectional=True,
                )

                if context_lstm_norm is not None:
                    # NOTE(review): if context_lstm_norm matches neither
                    # 'spectral' nor 'weight', lstm_norm_fn_pntr stays unbound
                    # and the calls below raise NameError
                    if 'spectral' in context_lstm_norm:
                        print("Applying spectral norm to context encoder LSTM")
                        lstm_norm_fn_pntr = torch.nn.utils.spectral_norm
                    elif 'weight' in context_lstm_norm:
                        print("Applying weight norm to context encoder LSTM")
                        lstm_norm_fn_pntr = torch.nn.utils.weight_norm

                    self.context_lstm = lstm_norm_fn_pntr(self.context_lstm, 'weight_hh_l0')
                    self.context_lstm = lstm_norm_fn_pntr(self.context_lstm, 'weight_hh_l0_reverse')

            if self.n_group_size > 1:
                # unfold/fold implement the "squeeze" grouping along time
                self.unfold_params = {
                    'kernel_size': (n_group_size, 1),
                    'stride': n_group_size,
                    'padding': 0,
                    'dilation': 1,
                }
                self.unfold = nn.Unfold(**self.unfold_params)

            self.exit_steps = []
            self.n_early_size = n_early_size
            n_mel_channels = n_mel_channels * n_group_size

            for i in range(self.n_flows):
                if i > 0 and i % n_early_every == 0:  # early exitting
                    # every n_early_every steps, n_early_size channels leave
                    # the flow (multi-scale architecture)
                    n_mel_channels -= self.n_early_size
                    self.exit_steps.append(i)

                self.flows.append(
                    FlowStep(
                        n_mel_channels,
                        n_flowstep_cond_dims,
                        n_conv_layers_per_step,
                        affine_model,
                        scaling_fn,
                        matrix_decomposition,
                        affine_activation=affine_activation,
                        use_partial_padding=self.decoder_use_partial_padding,
                    )
                )

        # --- duration predictor ---
        if 'dpm' in include_modules:
            dur_model_config['hparams']['n_speaker_dim'] = n_speaker_dim
            self.dur_pred_layer = get_attribute_prediction_model(dur_model_config)

        self.use_unvoiced_bias = False
        self.use_vpred_module = False
        self.ap_use_voiced_embeddings = kwargs['ap_use_voiced_embeddings']

        if self.decoder_use_unvoiced_bias or self.ap_use_unvoiced_bias:
            assert unvoiced_bias_activation in {'relu', 'exp'}
            self.use_unvoiced_bias = True
            if unvoiced_bias_activation == 'relu':
                unvbias_nonlin = nn.ReLU()
            elif unvoiced_bias_activation == 'exp':
                unvbias_nonlin = ExponentialClass()
            else:
                exit(1)  # we won't reach here anyway due to the assertion
            self.unvoiced_bias_module = nn.Sequential(LinearNorm(n_text_dim, 1), unvbias_nonlin)

        # all situations in which the vpred module is necessary
        if self.ap_use_voiced_embeddings or self.use_unvoiced_bias or 'vpred' in include_modules:
            self.use_vpred_module = True

        if self.use_vpred_module:
            v_model_config['hparams']['n_speaker_dim'] = n_speaker_dim
            self.v_pred_module = get_attribute_prediction_model(v_model_config)
            # 4 embeddings, first two are scales, second two are biases
            if self.ap_use_voiced_embeddings:
                self.v_embeddings = torch.nn.Embedding(4, n_text_dim)

        # --- f0 / energy attribute predictors ---
        if 'apm' in include_modules:
            f0_model_config['hparams']['n_speaker_dim'] = n_speaker_dim
            energy_model_config['hparams']['n_speaker_dim'] = n_speaker_dim
            if self.use_first_order_features:
                # predictors see (feature, d(feature)/dt) channel pairs
                f0_model_config['hparams']['n_in_dim'] = 2
                energy_model_config['hparams']['n_in_dim'] = 2
                if (
                    'spline_flow_params' in f0_model_config['hparams']
                    and f0_model_config['hparams']['spline_flow_params'] is not None
                ):
                    f0_model_config['hparams']['spline_flow_params']['n_in_channels'] = 2
                if (
                    'spline_flow_params' in energy_model_config['hparams']
                    and energy_model_config['hparams']['spline_flow_params'] is not None
                ):
                    energy_model_config['hparams']['spline_flow_params']['n_in_channels'] = 2
            else:
                # keep spline-flow channel counts in sync with n_in_dim
                if (
                    'spline_flow_params' in f0_model_config['hparams']
                    and f0_model_config['hparams']['spline_flow_params'] is not None
                ):
                    f0_model_config['hparams']['spline_flow_params']['n_in_channels'] = f0_model_config['hparams'][
                        'n_in_dim'
                    ]
                if (
                    'spline_flow_params' in energy_model_config['hparams']
                    and energy_model_config['hparams']['spline_flow_params'] is not None
                ):
                    energy_model_config['hparams']['spline_flow_params']['n_in_channels'] = energy_model_config[
                        'hparams'
                    ]['n_in_dim']

            self.f0_pred_module = get_attribute_prediction_model(f0_model_config)
            self.energy_pred_module = get_attribute_prediction_model(energy_model_config)

    def encode_speaker(self, spk_ids):
        spk_ids = spk_ids * 0 if self.dummy_speaker_embedding else spk_ids
        spk_vecs = self.speaker_embedding(spk_ids)
        return spk_vecs

    def encode_text(self, text, in_lens):
        # text_embeddings: b x len_text x n_text_dim
        text_embeddings = self.embedding(text).transpose(1, 2)
        # text_enc: b x n_text_dim x encoder_dim (512)
        if in_lens is None:
            text_enc = self.encoder.infer(text_embeddings).transpose(1, 2)
        else:
            text_enc = self.encoder(text_embeddings, in_lens).transpose(1, 2)

        return text_enc, text_embeddings

    def preprocess_context(self, context, speaker_vecs, out_lens=None, f0=None, energy_avg=None):
        """Assemble the per-frame conditioning tensor for the decoder flows.

        Groups ("squeezes") context/f0/energy along time when
        ``n_group_size > 1``, appends broadcast speaker vectors, and — when
        ``use_context_lstm`` is set — runs the result through the
        bidirectional context LSTM with length-masked packing.

        Args:
            context: b x C x T time-expanded text encoding
            speaker_vecs: b x n_speaker_dim speaker conditioning
            out_lens: b-dim frame lengths; required when use_context_lstm
            f0: optional frame-level f0 contour
            energy_avg: optional frame-level energy contour
        Returns:
            b x C' x (T // n_group_size) conditioning tensor
        """

        if self.n_group_size > 1:
            context = self.unfold(context.unsqueeze(-1))
            # (todo): fix unfolding zero-padded values
            if f0 is not None:
                f0 = self.unfold(f0[:, None, :, None])
            if energy_avg is not None:
                energy_avg = self.unfold(energy_avg[:, None, :, None])
        # broadcast the speaker vector over every (grouped) time step
        speaker_vecs = speaker_vecs[..., None].expand(-1, -1, context.shape[2])
        context_w_spkvec = torch.cat((context, speaker_vecs), 1)

        if self.use_context_lstm:
            if self.context_lstm_w_f0_and_energy:
                # f0/energy are fed to the LSTM and absorbed into its output
                if f0 is not None:
                    context_w_spkvec = torch.cat((context_w_spkvec, f0), 1)

                if energy_avg is not None:
                    context_w_spkvec = torch.cat((context_w_spkvec, energy_avg), 1)

            # pack so padded frames do not influence the LSTM states
            unfolded_out_lens = (out_lens // self.n_group_size).long().cpu()
            unfolded_out_lens_packed = nn.utils.rnn.pack_padded_sequence(
                context_w_spkvec.transpose(1, 2), unfolded_out_lens, batch_first=True, enforce_sorted=False
            )
            self.context_lstm.flatten_parameters()
            context_lstm_packed_output, _ = self.context_lstm(unfolded_out_lens_packed)
            context_lstm_padded_output, _ = nn.utils.rnn.pad_packed_sequence(
                context_lstm_packed_output, batch_first=True
            )
            context_w_spkvec = context_lstm_padded_output.transpose(1, 2)

        # without the f0/energy-aware LSTM, append the raw contours instead
        if not self.context_lstm_w_f0_and_energy:
            if f0 is not None:
                context_w_spkvec = torch.cat((context_w_spkvec, f0), 1)

            if energy_avg is not None:
                context_w_spkvec = torch.cat((context_w_spkvec, energy_avg), 1)

        return context_w_spkvec

    def fold(self, mel):
        """Inverse of the self.unfold(mel.unsqueeze(-1)) operation used for the
        grouping or "squeeze" operation on input

        Args:
            mel: B x C x T tensor of temporal data
        """
        mel = nn.functional.fold(mel, output_size=(mel.shape[2] * self.n_group_size, 1), **self.unfold_params).squeeze(
            -1
        )
        return mel

    def binarize_attention(self, attn, in_lens, out_lens):
        """For training purposes only. Binarizes attention with MAS. These will
        no longer receive a gradient.

        Args:
            attn: B x 1 x max_mel_len x max_text_len soft attention map
            in_lens: B-dim text lengths
            out_lens: B-dim mel lengths
        Returns:
            Hard (0/1) attention tensor with the same shape as ``attn``.
        """
        b_size = attn.shape[0]
        with torch.no_grad():
            attn_cpu = attn.data.cpu().numpy()
            attn_out = torch.zeros_like(attn)
            for ind in range(b_size):
                hard_attn = mas(attn_cpu[ind, 0, : out_lens[ind], : in_lens[ind]])
                # use attn.device rather than attn.get_device(): get_device()
                # returns -1 for CPU tensors, which is not a valid device
                # argument, so the original crashed for CPU inputs
                attn_out[ind, 0, : out_lens[ind], : in_lens[ind]] = torch.tensor(hard_attn, device=attn.device)
        return attn_out

    def get_first_order_features(self, feats, out_lens, dilation=1):
        """
        feats: b x max_length
        out_lens: b-dim
        """
        # add an extra column
        feats_extended_R = torch.cat((feats, torch.zeros_like(feats[:, 0:dilation])), dim=1)
        feats_extended_L = torch.cat((torch.zeros_like(feats[:, 0:dilation]), feats), dim=1)
        dfeats_R = feats_extended_R[:, dilation:] - feats
        dfeats_L = feats - feats_extended_L[:, 0:-dilation]

        return (dfeats_R + dfeats_L) * 0.5

    def apply_voice_mask_to_text(self, text_enc, voiced_mask):
        """
        text_enc: b x C x N
        voiced_mask: b x N
        """
        voiced_mask = voiced_mask.unsqueeze(1)
        voiced_embedding_s = self.v_embeddings.weight[0:1, :, None]
        unvoiced_embedding_s = self.v_embeddings.weight[1:2, :, None]
        voiced_embedding_b = self.v_embeddings.weight[2:3, :, None]
        unvoiced_embedding_b = self.v_embeddings.weight[3:4, :, None]
        scale = torch.sigmoid(voiced_embedding_s * voiced_mask + unvoiced_embedding_s * (1 - voiced_mask))
        bias = 0.1 * torch.tanh(voiced_embedding_b * voiced_mask + unvoiced_embedding_b * (1 - voiced_mask))
        return text_enc * scale + bias

    def forward(
        self,
        mel,
        speaker_ids,
        text,
        in_lens,
        out_lens,
        binarize_attention=False,
        attn_prior=None,
        f0=None,
        energy_avg=None,
        voiced_mask=None,
        p_voiced=None,
    ):
        """Training forward pass.

        Runs alignment attention, the flow decoder (mel -> z), and the
        duration / f0 / energy / voicing predictors, depending on which of
        'atn' / 'dec' / 'dpm' / 'apm' appear in ``self.include_modules``.

        Args:
            mel: b x n_mel_channels x max_frames mel spectrograms
            speaker_ids: b-dim speaker indices
            text: b x max_text token ids
            in_lens: b-dim text lengths
            out_lens: b-dim frame lengths
            binarize_attention: use MAS-binarized attention for the decoder
            attn_prior: optional prior passed to the attention module
            f0, energy_avg, voiced_mask: frame-level attribute targets
            p_voiced: unused here; accepted for interface compatibility
        Returns:
            dict of z_mel, flow log-determinant lists, attribute-predictor
            outputs, and soft/hard attention maps.
        """
        speaker_vecs = self.encode_speaker(speaker_ids)
        text_enc, text_embeddings = self.encode_text(text, in_lens)

        log_s_list, log_det_W_list, z_mel = [], [], []
        attn = None
        attn_soft = None
        attn_hard = None
        if 'atn' in self.include_modules or 'dec' in self.include_modules:
            # make sure to do the alignments before folding
            attn_mask = get_mask_from_lengths(in_lens)[..., None] == 0
            # attn_mask should be 1 for unused t-steps in text_enc_w_spkvec tensor
            attn_soft, attn_logprob = self.attention(
                mel, text_embeddings, out_lens, attn_mask, key_lens=in_lens, attn_prior=attn_prior
            )

            if binarize_attention:
                attn = self.binarize_attention(attn_soft, in_lens, out_lens)
                attn_hard = attn
            else:
                attn = attn_soft

            # time-expand the text encodings via the attention map
            context = torch.bmm(text_enc, attn.squeeze(1).transpose(1, 2))

        f0_bias = 0
        # unvoiced bias forward pass
        if self.use_unvoiced_bias:
            f0_bias = self.unvoiced_bias_module(context.permute(0, 2, 1))
            f0_bias = -f0_bias[..., 0]
            # bias only applies to unvoiced frames
            f0_bias = f0_bias * (~voiced_mask.bool()).float()

        # mel decoder forward pass
        if 'dec' in self.include_modules:
            if self.n_group_size > 1:
                # might truncate some frames at the end, but that's ok
                # sometimes referred to as the "squeeeze" operation
                # invert this by calling self.fold(mel_or_z)
                mel = self.unfold(mel.unsqueeze(-1))
            z_out = []
            # where context is folded
            # mask f0 in case values are interpolated
            if self.decoder_use_unvoiced_bias:
                context_w_spkvec = self.preprocess_context(
                    context, speaker_vecs, out_lens, f0 * voiced_mask + f0_bias, energy_avg
                )
            else:
                context_w_spkvec = self.preprocess_context(
                    context, speaker_vecs, out_lens, f0 * voiced_mask, energy_avg
                )

            log_s_list, log_det_W_list, z_out = [], [], []
            unfolded_seq_lens = out_lens // self.n_group_size
            for i, flow_step in enumerate(self.flows):
                if i in self.exit_steps:
                    # multi-scale early exit: peel off channels into z
                    z = mel[:, : self.n_early_size]
                    z_out.append(z)
                    mel = mel[:, self.n_early_size :]
                mel, log_det_W, log_s = flow_step(mel, context_w_spkvec, seq_lens=unfolded_seq_lens)
                log_s_list.append(log_s)
                log_det_W_list.append(log_det_W)

            z_out.append(mel)
            z_mel = torch.cat(z_out, 1)

        # duration predictor forward pass
        duration_model_outputs = None
        if 'dpm' in self.include_modules:
            if attn_hard is None:
                attn_hard = self.binarize_attention(attn_soft, in_lens, out_lens)

            # convert hard attention to durations
            attn_hard_reduced = attn_hard.sum(2)[:, 0, :]
            # detach: the duration predictor must not backprop into the aligner
            duration_model_outputs = self.dur_pred_layer(
                torch.detach(text_enc), torch.detach(speaker_vecs), torch.detach(attn_hard_reduced.float()), in_lens
            )

        # f0, energy, vpred predictors forward pass
        f0_model_outputs = None
        energy_model_outputs = None
        vpred_model_outputs = None
        if 'apm' in self.include_modules:
            if attn_hard is None:
                attn_hard = self.binarize_attention(attn_soft, in_lens, out_lens)

            # convert hard attention to durations
            if binarize_attention:
                text_enc_time_expanded = context.clone()
            else:
                text_enc_time_expanded = torch.bmm(text_enc, attn_hard.squeeze(1).transpose(1, 2))

            if self.use_vpred_module:
                # unvoiced bias requires voiced mask prediction
                vpred_model_outputs = self.v_pred_module(
                    torch.detach(text_enc_time_expanded),
                    torch.detach(speaker_vecs),
                    torch.detach(voiced_mask),
                    out_lens,
                )

                # affine transform context using voiced mask
                if self.ap_use_voiced_embeddings:
                    text_enc_time_expanded = self.apply_voice_mask_to_text(text_enc_time_expanded, voiced_mask)
            if self.ap_use_unvoiced_bias:  # whether to use the unvoiced bias in the attribute predictor
                f0_target = torch.detach(f0 * voiced_mask + f0_bias)
            else:
                f0_target = torch.detach(f0)
            # fit to log f0 in f0 predictor
            f0_target[voiced_mask.bool()] = torch.log(f0_target[voiced_mask.bool()])
            f0_target = f0_target / 6  # scale to ~ [0, 1] in log space
            energy_avg = energy_avg * 2 - 1  # scale to ~ [-1, 1]

            if self.use_first_order_features:
                df0 = self.get_first_order_features(f0_target, out_lens)
                denergy_avg = self.get_first_order_features(energy_avg, out_lens)

                f0_voiced = torch.cat((f0_target[:, None], df0[:, None]), dim=1)
                energy_avg = torch.cat((energy_avg[:, None], denergy_avg[:, None]), dim=1)

                f0_voiced = f0_voiced * 3  # scale to ~ 1 std
                energy_avg = energy_avg * 3  # scale to ~ 1 std
            else:
                f0_voiced = f0_target * 2  # scale to ~ 1 std
                energy_avg = energy_avg * 1.4  # scale to ~ 1 std
            f0_model_outputs = self.f0_pred_module(
                text_enc_time_expanded, torch.detach(speaker_vecs), f0_voiced, out_lens
            )

            energy_model_outputs = self.energy_pred_module(
                text_enc_time_expanded, torch.detach(speaker_vecs), energy_avg, out_lens
            )

        outputs = {
            'z_mel': z_mel,
            'log_det_W_list': log_det_W_list,
            'log_s_list': log_s_list,
            'duration_model_outputs': duration_model_outputs,
            'f0_model_outputs': f0_model_outputs,
            'energy_model_outputs': energy_model_outputs,
            'vpred_model_outputs': vpred_model_outputs,
            'attn_soft': attn_soft,
            'attn': attn,
            'text_embeddings': text_embeddings,
            'attn_logprob': attn_logprob,
        }

        return outputs

    def infer(
        self,
        speaker_id,
        text,
        sigma,
        sigma_txt=0.8,
        sigma_f0=0.8,
        sigma_energy=0.8,
        token_dur_scaling=1.0,
        token_duration_max=100,
        dur=None,
        f0=None,
        energy_avg=None,
        voiced_mask=None,
    ):
        """Inference: sample durations / f0 / energy (unless supplied) and run
        the flow decoder in reverse (z -> mel).

        Args:
            speaker_id: b-dim speaker index tensor
            text: b x n_tokens token ids
            sigma: decoder sampling temperature
            sigma_txt, sigma_f0, sigma_energy: attribute sampling temperatures
            token_dur_scaling: multiplicative duration scaling factor
            token_duration_max: per-token duration clamp (frames)
            dur, f0, energy_avg, voiced_mask: optional precomputed attributes
        Returns:
            dict with 'mel', 'dur', 'f0', 'energy_avg'.

        NOTE(review): latents are sampled with torch.cuda.FloatTensor, so
        this path requires a CUDA device.
        """

        n_tokens = text.shape[1]
        spk_vec = self.encode_speaker(speaker_id)
        txt_enc, txt_emb = self.encode_text(text, None)

        if dur is None:
            # get token durations
            z_dur = torch.cuda.FloatTensor(1, 1, n_tokens)
            z_dur = z_dur.normal_() * sigma_txt

            dur = self.dur_pred_layer.infer(z_dur, txt_enc, spk_vec)
            # duration flow may come back short; pad by repeating the edge
            if dur.shape[-1] < txt_enc.shape[-1]:
                to_pad = txt_enc.shape[-1] - dur.shape[2]
                pad_fn = nn.ReplicationPad1d((0, to_pad))
                dur = pad_fn(dur)
            dur = dur[:, 0]
            dur = dur.clamp(0, token_duration_max)
            dur = dur * token_dur_scaling if token_dur_scaling > 0 else dur
            # round to whole frames
            dur = (dur + 0.5).floor().int()

        n_frames = dur.sum().item()
        out_lens = torch.LongTensor([n_frames]).to(txt_enc.device)

        # get attributes f0, energy, vpred, etc)
        txt_enc_time_expanded = self.length_regulator(txt_enc.transpose(1, 2), dur).transpose(1, 2)

        if voiced_mask is None:
            if self.use_vpred_module:
                # get logits
                voiced_mask = self.v_pred_module.infer(None, txt_enc_time_expanded, spk_vec)
                voiced_mask = torch.sigmoid(voiced_mask[:, 0]) > 0.5
                voiced_mask = voiced_mask.float()

        ap_txt_enc_time_expanded = txt_enc_time_expanded
        # voice mask augmentation only used for attribute prediction
        if self.ap_use_voiced_embeddings:
            ap_txt_enc_time_expanded = self.apply_voice_mask_to_text(txt_enc_time_expanded, voiced_mask)

        f0_bias = 0
        # unvoiced bias forward pass
        if self.use_unvoiced_bias:
            f0_bias = self.unvoiced_bias_module(txt_enc_time_expanded.permute(0, 2, 1))
            f0_bias = -f0_bias[..., 0]
            f0_bias = f0_bias * (~voiced_mask.bool()).float()

        if f0 is None:
            n_f0_feature_channels = 2 if self.use_first_order_features else 1
            z_f0 = torch.cuda.FloatTensor(1, n_f0_feature_channels, n_frames).normal_() * sigma_f0
            f0 = self.infer_f0(z_f0, ap_txt_enc_time_expanded, spk_vec, voiced_mask, out_lens)[:, 0]

        if energy_avg is None:
            n_energy_feature_channels = 2 if self.use_first_order_features else 1
            z_energy_avg = torch.cuda.FloatTensor(1, n_energy_feature_channels, n_frames).normal_() * sigma_energy
            energy_avg = self.infer_energy(z_energy_avg, ap_txt_enc_time_expanded, spk_vec, out_lens)[:, 0]

        # replication pad, because ungrouping with different group sizes
        # may lead to mismatched lengths
        if energy_avg.shape[1] < out_lens[0]:
            to_pad = out_lens[0] - energy_avg.shape[1]
            pad_fn = nn.ReplicationPad1d((0, to_pad))
            # f0 = pad_fn(f0[None])[0]
            energy_avg = pad_fn(energy_avg[None])[0]
        if f0.shape[1] < out_lens[0]:
            to_pad = out_lens[0] - f0.shape[1]
            pad_fn = nn.ReplicationPad1d((0, to_pad))
            f0 = pad_fn(f0[None])[0]

        if self.decoder_use_unvoiced_bias:
            context_w_spkvec = self.preprocess_context(
                txt_enc_time_expanded, spk_vec, out_lens, f0 * voiced_mask + f0_bias, energy_avg
            )

        else:
            context_w_spkvec = self.preprocess_context(
                txt_enc_time_expanded, spk_vec, out_lens, f0 * voiced_mask, energy_avg
            )

        # NOTE(review): 80 here is presumably n_mel_channels — confirm it
        # matches the configured mel dimensionality
        residual = torch.cuda.FloatTensor(1, 80 * self.n_group_size, n_frames // self.n_group_size)
        residual = residual.normal_() * sigma

        # map from z sample to data
        # walk the flows in reverse, re-injecting early-exited channels at
        # the steps where training peeled them off
        exit_steps_stack = self.exit_steps.copy()
        mel = residual[:, len(exit_steps_stack) * self.n_early_size :]
        remaining_residual = residual[:, : len(exit_steps_stack) * self.n_early_size]
        unfolded_seq_lens = out_lens // self.n_group_size
        for i, flow_step in enumerate(reversed(self.flows)):
            curr_step = len(self.flows) - i - 1
            mel = flow_step(mel, context_w_spkvec, inverse=True, seq_lens=unfolded_seq_lens)
            if len(exit_steps_stack) > 0 and curr_step == exit_steps_stack[-1]:
                # concatenate the next chunk of z
                exit_steps_stack.pop()
                residual_to_add = remaining_residual[:, len(exit_steps_stack) * self.n_early_size :]
                remaining_residual = remaining_residual[:, : len(exit_steps_stack) * self.n_early_size]
                mel = torch.cat((residual_to_add, mel), 1)

        if self.n_group_size > 1:
            # undo the "squeeze" grouping
            mel = self.fold(mel)

        return {'mel': mel, 'dur': dur, 'f0': f0, 'energy_avg': energy_avg}

    def infer_f0(self, residual, txt_enc_time_expanded, spk_vec, voiced_mask=None, lens=None):
        print("txt_enc_time_expanded", txt_enc_time_expanded.size())
        print("spk_vec", spk_vec.size())
        f0 = self.f0_pred_module.infer(residual, txt_enc_time_expanded, spk_vec, lens)

        if voiced_mask is not None and len(voiced_mask.shape) == 2:
            voiced_mask = voiced_mask[:, None]
        # constants
        if self.ap_pred_log_f0:
            if self.use_first_order_features:
                f0 = f0[:, 0:1, :] / 3
            else:
                f0 = f0 / 2
            f0 = f0 * 6
        else:
            f0 = f0 / 6
            f0 = f0 / 640

        if voiced_mask is None:
            voiced_mask = f0 > 0.0
        else:
            voiced_mask = voiced_mask.bool()
        # due to grouping, f0 might be 1 frame short
        voiced_mask = voiced_mask[:, :, : f0.shape[-1]]
        if self.ap_pred_log_f0:
            # if variable is set, decoder sees linear f0
            # mask = f0 > 0.0 if voiced_mask is None else voiced_mask.bool()
            f0[voiced_mask] = torch.exp(f0[voiced_mask])
        f0[~voiced_mask] = 0.0
        return f0

    def infer_energy(self, residual, txt_enc_time_expanded, spk_vec, lens):
        energy = self.energy_pred_module.infer(residual, txt_enc_time_expanded, spk_vec, lens)

        # magic constants
        if self.use_first_order_features:
            energy = energy / 3
        else:
            energy = energy / 1.4
        energy = (energy + 1) / 2
        return energy

    def remove_norms(self):
        """Removes spectral and weightnorms from model. Call before inference
        """
        for name, module in self.named_modules():
            try:
                nn.utils.remove_spectral_norm(module, name='weight_hh_l0')
                print("Removed spectral norm from {}".format(name))
            except:
                pass
            try:
                nn.utils.remove_spectral_norm(module, name='weight_hh_l0_reverse')
                print("Removed spectral norm from {}".format(name))
            except:
                pass
            try:
                nn.utils.remove_weight_norm(module)
                print("Removed wnorm from {}".format(name))
            except:
                pass
